{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline\n",
    "\n",
    "import sys\n",
    "sys.path.append(\"../\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Global seed set to 42\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
      "\n",
      "  | Name      | Type    | Params\n",
      "--------------------------------------\n",
      "0 | net       | SIDLM   | 151 K \n",
      "1 | criterion | MSELoss | 0     \n",
      "2 | metric_r2 | R2Score | 0     \n",
      "--------------------------------------\n",
      "151 K     Trainable params\n",
      "0         Non-trainable params\n",
      "151 K     Total params\n",
      "0.608     Total estimated model params size (MB)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e8cad9d2e48249839e9bcc8b5fbaa65e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Sanity Checking: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/alpha/mambaforge/envs/alpha/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:1609: PossibleUserWarning: The number of training batches (42) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n",
      "  rank_zero_warn(\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c28e16cea6c24d4f945448b212280f22",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2dc65f1480e64970ad122db19764ad5f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric valid_r2 improved. New best score: -0.053\n",
      "Epoch 0, global step 42: 'valid_r2' reached -0.05344 (best -0.05344), saving model to '../outputs/tv_sidlm/lightning_logs/version_6/checkpoints/epoch=0-step=42.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5bbd6aa0992e4e5f8d58198e5ac2f29e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric valid_r2 improved by 0.291 >= min_delta = 0.0. New best score: 0.237\n",
      "Epoch 1, global step 84: 'valid_r2' reached 0.23730 (best 0.23730), saving model to '../outputs/tv_sidlm/lightning_logs/version_6/checkpoints/epoch=1-step=84.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4ceb6280ec894eb0b1d77abc897799e2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric valid_r2 improved by 0.136 >= min_delta = 0.0. New best score: 0.373\n",
      "Epoch 2, global step 126: 'valid_r2' reached 0.37339 (best 0.37339), saving model to '../outputs/tv_sidlm/lightning_logs/version_6/checkpoints/epoch=2-step=126.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "59c18af3fa854feaa16bb81e98a5c622",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric valid_r2 improved by 0.111 >= min_delta = 0.0. New best score: 0.484\n",
      "Epoch 3, global step 168: 'valid_r2' reached 0.48418 (best 0.48418), saving model to '../outputs/tv_sidlm/lightning_logs/version_6/checkpoints/epoch=3-step=168.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cbdb2fc184a14b2a9ed526b21fe37a7e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric valid_r2 improved by 0.016 >= min_delta = 0.0. New best score: 0.500\n",
      "Epoch 4, global step 210: 'valid_r2' reached 0.50022 (best 0.50022), saving model to '../outputs/tv_sidlm/lightning_logs/version_6/checkpoints/epoch=4-step=210.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9d7467d0399b4cf19b7bcd910df165fd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric valid_r2 improved by 0.054 >= min_delta = 0.0. New best score: 0.555\n",
      "Epoch 5, global step 252: 'valid_r2' reached 0.55454 (best 0.55454), saving model to '../outputs/tv_sidlm/lightning_logs/version_6/checkpoints/epoch=5-step=252.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4d688edc68994d158ffe2611c3eda744",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric valid_r2 improved by 0.013 >= min_delta = 0.0. New best score: 0.568\n",
      "Epoch 6, global step 294: 'valid_r2' reached 0.56787 (best 0.56787), saving model to '../outputs/tv_sidlm/lightning_logs/version_6/checkpoints/epoch=6-step=294.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4dd194591fbf4f3b90c29dccf1e25cca",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric valid_r2 improved by 0.019 >= min_delta = 0.0. New best score: 0.587\n",
      "Epoch 7, global step 336: 'valid_r2' reached 0.58691 (best 0.58691), saving model to '../outputs/tv_sidlm/lightning_logs/version_6/checkpoints/epoch=7-step=336.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fd4812ba005545c38cd1f4b7ddcbdc7a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric valid_r2 improved by 0.012 >= min_delta = 0.0. New best score: 0.599\n",
      "Epoch 8, global step 378: 'valid_r2' reached 0.59864 (best 0.59864), saving model to '../outputs/tv_sidlm/lightning_logs/version_6/checkpoints/epoch=8-step=378.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4d1c3175c8994bab90577891b141ed9c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric valid_r2 improved by 0.008 >= min_delta = 0.0. New best score: 0.606\n",
      "Epoch 9, global step 420: 'valid_r2' reached 0.60648 (best 0.60648), saving model to '../outputs/tv_sidlm/lightning_logs/version_6/checkpoints/epoch=9-step=420.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "dc39ddd5007b479898662a6e6f43817b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric valid_r2 improved by 0.010 >= min_delta = 0.0. New best score: 0.616\n",
      "Epoch 10, global step 462: 'valid_r2' reached 0.61641 (best 0.61641), saving model to '../outputs/tv_sidlm/lightning_logs/version_6/checkpoints/epoch=10-step=462.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "05dd6ede12c9441c980bcacb8e3604e1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 11, global step 504: 'valid_r2' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "269e9b2a9ed54061a7cdc8f211ce02c8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 12, global step 546: 'valid_r2' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a0d2fd8e60e642058b54cf6b93027da4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric valid_r2 improved by 0.009 >= min_delta = 0.0. New best score: 0.625\n",
      "Epoch 13, global step 588: 'valid_r2' reached 0.62547 (best 0.62547), saving model to '../outputs/tv_sidlm/lightning_logs/version_6/checkpoints/epoch=13-step=588.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a259dadb3bd74b21b164412008910dd5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric valid_r2 improved by 0.025 >= min_delta = 0.0. New best score: 0.650\n",
      "Epoch 14, global step 630: 'valid_r2' reached 0.65011 (best 0.65011), saving model to '../outputs/tv_sidlm/lightning_logs/version_6/checkpoints/epoch=14-step=630.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ffdf80344c5f451493c0dfc61bcc72d0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 15, global step 672: 'valid_r2' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2e4e13c0dcac48cf9d306a9de8fe4aa5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric valid_r2 improved by 0.013 >= min_delta = 0.0. New best score: 0.663\n",
      "Epoch 16, global step 714: 'valid_r2' reached 0.66345 (best 0.66345), saving model to '../outputs/tv_sidlm/lightning_logs/version_6/checkpoints/epoch=16-step=714.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0cdc2438f3964746b72eff8ac14ed015",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 17, global step 756: 'valid_r2' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "034fbf3993444d7db6b0f120951b4b49",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric valid_r2 improved by 0.003 >= min_delta = 0.0. New best score: 0.666\n",
      "Epoch 18, global step 798: 'valid_r2' reached 0.66641 (best 0.66641), saving model to '../outputs/tv_sidlm/lightning_logs/version_6/checkpoints/epoch=18-step=798.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6937a6e8711549bc9de2f92b7cd6e66f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 19, global step 840: 'valid_r2' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "227bd1af900044c68e952084bda4d33c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 20, global step 882: 'valid_r2' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "23a1cf8b95274e69bad955c3b3ba1142",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric valid_r2 improved by 0.008 >= min_delta = 0.0. New best score: 0.675\n",
      "Epoch 21, global step 924: 'valid_r2' reached 0.67482 (best 0.67482), saving model to '../outputs/tv_sidlm/lightning_logs/version_6/checkpoints/epoch=21-step=924.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "37eb3104a35f40ef87dd30c16e61ca58",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric valid_r2 improved by 0.012 >= min_delta = 0.0. New best score: 0.687\n",
      "Epoch 22, global step 966: 'valid_r2' reached 0.68670 (best 0.68670), saving model to '../outputs/tv_sidlm/lightning_logs/version_6/checkpoints/epoch=22-step=966.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "063c515f452547b1bef0c27208e11c63",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric valid_r2 improved by 0.001 >= min_delta = 0.0. New best score: 0.687\n",
      "Epoch 23, global step 1008: 'valid_r2' reached 0.68726 (best 0.68726), saving model to '../outputs/tv_sidlm/lightning_logs/version_6/checkpoints/epoch=23-step=1008.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "16d6e26259dd4d9180ef807447acef5a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 24, global step 1050: 'valid_r2' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "dcba9fa27b974ea49a86f336fd205439",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 25, global step 1092: 'valid_r2' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "001c1ce9e9f74b45b616c9a86b0b7648",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 26, global step 1134: 'valid_r2' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d851c2e6ec874a608866f25ed0278fef",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric valid_r2 improved by 0.007 >= min_delta = 0.0. New best score: 0.695\n",
      "Epoch 27, global step 1176: 'valid_r2' reached 0.69468 (best 0.69468), saving model to '../outputs/tv_sidlm/lightning_logs/version_6/checkpoints/epoch=27-step=1176.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4e46fc7528704d50ab28f15fb487821d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 28, global step 1218: 'valid_r2' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1f2491a908174b8e93953ab24c2426d7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric valid_r2 improved by 0.001 >= min_delta = 0.0. New best score: 0.696\n",
      "Epoch 29, global step 1260: 'valid_r2' reached 0.69578 (best 0.69578), saving model to '../outputs/tv_sidlm/lightning_logs/version_6/checkpoints/epoch=29-step=1260.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b503121fc5bc43be84fedf226f1434a5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 30, global step 1302: 'valid_r2' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f59abb9010a5401d87aa315b884b2121",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric valid_r2 improved by 0.007 >= min_delta = 0.0. New best score: 0.703\n",
      "Epoch 31, global step 1344: 'valid_r2' reached 0.70271 (best 0.70271), saving model to '../outputs/tv_sidlm/lightning_logs/version_6/checkpoints/epoch=31-step=1344.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2dd55ceaee734c209823c6641673f55c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric valid_r2 improved by 0.000 >= min_delta = 0.0. New best score: 0.703\n",
      "Epoch 32, global step 1386: 'valid_r2' reached 0.70275 (best 0.70275), saving model to '../outputs/tv_sidlm/lightning_logs/version_6/checkpoints/epoch=32-step=1386.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7071a63fec6a4795833c9baad1f6b4b1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric valid_r2 improved by 0.004 >= min_delta = 0.0. New best score: 0.707\n",
      "Epoch 33, global step 1428: 'valid_r2' reached 0.70706 (best 0.70706), saving model to '../outputs/tv_sidlm/lightning_logs/version_6/checkpoints/epoch=33-step=1428.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "16727f4956a04dd887a087392cf4b05a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 34, global step 1470: 'valid_r2' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "607a26ef11424ac598e07dedc3f7ef02",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric valid_r2 improved by 0.002 >= min_delta = 0.0. New best score: 0.709\n",
      "Epoch 35, global step 1512: 'valid_r2' reached 0.70900 (best 0.70900), saving model to '../outputs/tv_sidlm/lightning_logs/version_6/checkpoints/epoch=35-step=1512.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "896e9d3f12374bcfbb981a671dff8fbd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric valid_r2 improved by 0.001 >= min_delta = 0.0. New best score: 0.710\n",
      "Epoch 36, global step 1554: 'valid_r2' reached 0.71037 (best 0.71037), saving model to '../outputs/tv_sidlm/lightning_logs/version_6/checkpoints/epoch=36-step=1554.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5e47dc45b8614533bac66b4765e00286",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 37, global step 1596: 'valid_r2' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "94ad986f70da4454b575db9b38f98a03",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 38, global step 1638: 'valid_r2' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6bd384ad87684d94896a883f848d8b91",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric valid_r2 improved by 0.000 >= min_delta = 0.0. New best score: 0.710\n",
      "Epoch 39, global step 1680: 'valid_r2' reached 0.71041 (best 0.71041), saving model to '../outputs/tv_sidlm/lightning_logs/version_6/checkpoints/epoch=39-step=1680.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c093959f344942d8af41ac8d4be18c3f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric valid_r2 improved by 0.002 >= min_delta = 0.0. New best score: 0.712\n",
      "Epoch 40, global step 1722: 'valid_r2' reached 0.71249 (best 0.71249), saving model to '../outputs/tv_sidlm/lightning_logs/version_6/checkpoints/epoch=40-step=1722.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fd7e5f5343d0449a9bee23befd7c21c3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric valid_r2 improved by 0.008 >= min_delta = 0.0. New best score: 0.721\n",
      "Epoch 41, global step 1764: 'valid_r2' reached 0.72099 (best 0.72099), saving model to '../outputs/tv_sidlm/lightning_logs/version_6/checkpoints/epoch=41-step=1764.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5be9fa546c4349e696c1bd43e69d0848",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 42, global step 1806: 'valid_r2' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "03b015983d55444ebe26ac3e3f2a5fb5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 43, global step 1848: 'valid_r2' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e20c4a37329b45d5a2b384bfdea55dfc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric valid_r2 improved by 0.003 >= min_delta = 0.0. New best score: 0.724\n",
      "Epoch 44, global step 1890: 'valid_r2' reached 0.72437 (best 0.72437), saving model to '../outputs/tv_sidlm/lightning_logs/version_6/checkpoints/epoch=44-step=1890.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ca733a626fb44e0fb5ebd894fff9038f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 45, global step 1932: 'valid_r2' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d3ceee47e4864c3f89363ec5921cbfb3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 46, global step 1974: 'valid_r2' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "dbe4f5052f75423e86d0d68ced044977",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 47, global step 2016: 'valid_r2' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "773c156463d04f478bfd5aaa4963bed4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 48, global step 2058: 'valid_r2' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cd6e2d75013c40a2b18d27f86bdc34b7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 49, global step 2100: 'valid_r2' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "afb94214b826496893982f2538baad7f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 50, global step 2142: 'valid_r2' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f9e29dab8eef4e5eb6196b8b9980cb37",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 51, global step 2184: 'valid_r2' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "08d80755e1174a27b65b1ebaa98d7076",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric valid_r2 improved by 0.002 >= min_delta = 0.0. New best score: 0.727\n",
      "Epoch 52, global step 2226: 'valid_r2' reached 0.72664 (best 0.72664), saving model to '../outputs/tv_sidlm/lightning_logs/version_6/checkpoints/epoch=52-step=2226.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "35822c1075bf4d9484ea0f6d7700e782",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 53, global step 2268: 'valid_r2' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e9abd6e239734e52957f008a31a15def",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 54, global step 2310: 'valid_r2' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fb1b6e71ceaf43bb95a6b3191a011849",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric valid_r2 improved by 0.003 >= min_delta = 0.0. New best score: 0.729\n",
      "Epoch 55, global step 2352: 'valid_r2' reached 0.72922 (best 0.72922), saving model to '../outputs/tv_sidlm/lightning_logs/version_6/checkpoints/epoch=55-step=2352.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e7057ac8860d4648840c785040c20a4a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric valid_r2 improved by 0.001 >= min_delta = 0.0. New best score: 0.731\n",
      "Epoch 56, global step 2394: 'valid_r2' reached 0.73066 (best 0.73066), saving model to '../outputs/tv_sidlm/lightning_logs/version_6/checkpoints/epoch=56-step=2394.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a79e9e6100194088a7d3fc7d7ab02320",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 57, global step 2436: 'valid_r2' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2d64293528c64e3bb06bd71104588ec2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 58, global step 2478: 'valid_r2' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5d753e42d45545d097c54b5111cc59a4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric valid_r2 improved by 0.008 >= min_delta = 0.0. New best score: 0.739\n",
      "Epoch 59, global step 2520: 'valid_r2' reached 0.73905 (best 0.73905), saving model to '../outputs/tv_sidlm/lightning_logs/version_6/checkpoints/epoch=59-step=2520.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d2c952c860aa4409acdf7d39ae6e32b3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 60, global step 2562: 'valid_r2' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c595d41fb4954454b0e0532d9f196d27",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 61, global step 2604: 'valid_r2' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "71a39e8c138b4c52ad9af9717ce7dbd1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 62, global step 2646: 'valid_r2' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5398019f2586417bab9ad766aab97850",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 63, global step 2688: 'valid_r2' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5bc38c5dbf2940c7989f3f1a8633201d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 64, global step 2730: 'valid_r2' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bd6ed47e437f4bf99ae010855ab0e198",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 65, global step 2772: 'valid_r2' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6814014b45cc4bbc991376e397ef159d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 66, global step 2814: 'valid_r2' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a0298429db1a4ede8ea64e005f4d8643",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric valid_r2 improved by 0.008 >= min_delta = 0.0. New best score: 0.747\n",
      "Epoch 67, global step 2856: 'valid_r2' reached 0.74707 (best 0.74707), saving model to '../outputs/tv_sidlm/lightning_logs/version_6/checkpoints/epoch=67-step=2856.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2ca3f6bd9cab4e97adca80bbf2d569bf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 68, global step 2898: 'valid_r2' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "302f9b66c3b4491a88faddda28b99e2f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 69, global step 2940: 'valid_r2' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "dc01b167b241467e8d5a2f0fa128449e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 70, global step 2982: 'valid_r2' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f77ba332c1b147ab8d1e974ad0fa15a0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 71, global step 3024: 'valid_r2' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "505de2f1320e443e8e566807519f00f1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 72, global step 3066: 'valid_r2' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "813dcbc53a034ba79548bcc4f1a96292",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 73, global step 3108: 'valid_r2' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "59152c553d3b43278aef84926a489383",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 74, global step 3150: 'valid_r2' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3b48a43f90a4418aa94e5da8bbe509c9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 75, global step 3192: 'valid_r2' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "10d20f9ea9e5451fa25f20f4844cc985",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 76, global step 3234: 'valid_r2' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c8febf96772c4b548f35204a53ced89b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 77, global step 3276: 'valid_r2' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4ed905829d0f490ba62442e8b8b28893",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 78, global step 3318: 'valid_r2' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a2eede1c67fc410886a2fe2a580758c0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 79, global step 3360: 'valid_r2' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1545a85b647b438d8d11d1b68011d51e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 80, global step 3402: 'valid_r2' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "57c68dcf4a744445af6b912d79c79db4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 81, global step 3444: 'valid_r2' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f8755479ead249c8b15b8aadf02125c9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 82, global step 3486: 'valid_r2' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cf8b81a1d8a4460c9852044fb179e75c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 83, global step 3528: 'valid_r2' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "51d543989b0e40159f7780b16e76f39e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 84, global step 3570: 'valid_r2' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a7b92034d14a4fc68a778778e5f9e91d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 85, global step 3612: 'valid_r2' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "68795db46e9a4fe7997c2aadf057ccc8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 86, global step 3654: 'valid_r2' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "90df54b67cd9458e81ff723c2db6aa2e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Monitored metric valid_r2 did not improve in the last 20 records. Best score: 0.747. Signaling Trainer to stop.\n",
      "Epoch 87, global step 3696: 'valid_r2' was not in top 1\n",
      "Restoring states from the checkpoint path at ../outputs/tv_sidlm/lightning_logs/version_6/checkpoints/epoch=67-step=2856.ckpt\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
      "Loaded model weights from checkpoint at ../outputs/tv_sidlm/lightning_logs/version_6/checkpoints/epoch=67-step=2856.ckpt\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9fcdfde38c134ea99feff26bd0d9cfc1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Predicting: 42it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'COUNT': 3581, 'RMSE': 0.09237754542347659, 'MAE': 0.047186988121142026, 'R2': 0.7887953807840447, 'PEARSON': 0.8881415319553783}\n"
     ]
    }
   ],
   "source": [
    "from esidlm.learner.sidlm import SIDLMLearner\n",
    "\n",
    "SIDLM_TRAINING_CONFIG = {\n",
    "\n",
    "    \"global\": {\n",
    "        \"seed\": 42,\n",
    "        \"output_folder\": \"../outputs/tv_sidlm\",\n",
    "    },\n",
    "\n",
    "    \"data\": {\n",
    "        \"train_data\": \"../../geo-tabular-models/data/angle/train.csv\",\n",
    "        \"valid_data\": \"../../geo-tabular-models/data/angle/valid.csv\",\n",
    "        \"test_data\": \"../../geo-tabular-models/data/angle/test.csv\",\n",
    "\n",
    "        \"wide_cols\": [\"season\"],\n",
    "        \"cont_cols\": [\"B1\", \"B2\", \"B3\", \"B4\",\"B5\", \"B_M_1\", \"B_M_2\", \"B_M_3\", \"B_M_4\",\"B_M_5\",\n",
    "                    \"B6\", \"B7\", \"B8\", \"B9\", \"B10\",\"B_M_6\", \"B_M_7\", \"B_M_8\", \"B_M_9\",\"B_M_10\",\n",
    "                    \"M1\", \"M2\", \"M3\", \"M4\", \"M5\",'M6',\"M7\", \"M8\", \"M9\",\"M10\",\"NDVI_F\",'NDVI_B',\n",
    "                    \"SolZ_F\",\"SenZ_F\",\"SCA_F\",\"SolZ_B\",\"SenZ_B\",\"SCA_B\"],\n",
    "        \"cate_cols\": [\"month\"],\n",
    "        \"target_col\":\"Fine_Mode_AOD_500nm\",\n",
    "    },\n",
    "\n",
    "    \"dataloader\": {\n",
    "        \"batch_size\": 256,\n",
    "        \"num_workers\": 4,\n",
    "    },\n",
    "\n",
    "    \"model\": {\n",
    "        \"net\": {\n",
    "            \"d_embed\": 32,\n",
    "            \"d_model\": 256,\n",
    "            \"n_layers\": 2,\n",
    "            \"p_drop\": 0.1,\n",
    "            \"act_fn\": \"relu\"\n",
    "        },\n",
    "        \"optimizer\": {\n",
    "            \"lr\": 0.0005,\n",
    "            \"weight_decay\": 1e-4,\n",
    "        },\n",
    "    },\n",
    "\n",
    "    \"callback\": {\n",
    "        \"model_checkpoint\": {\n",
    "            \"save_top_k\": 1,\n",
    "            \"monitor\": \"valid_r2\",\n",
    "            \"mode\": \"max\",\n",
    "            \"verbose\": True\n",
    "        },\n",
    "        \"early_stopping\": {\n",
    "            \"monitor\": \"valid_r2\",\n",
    "            \"mode\": \"max\",\n",
    "            \"patience\": 20,\n",
    "            \"verbose\": True\n",
    "        }\n",
    "    },\n",
    "\n",
    "    \"trainer\": {\n",
    "        \"max_epochs\": 1000,\n",
    "        \"accelerator\": \"gpu\",\n",
    "        \"devices\": 1,\n",
    "        \"deterministic\": True\n",
    "    }\n",
    "}\n",
    "learner = SIDLMLearner(SIDLM_TRAINING_CONFIG)\n",
    "learner.run_model_training()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>index</th>\n",
       "      <th>B1</th>\n",
       "      <th>B2</th>\n",
       "      <th>B3</th>\n",
       "      <th>B4</th>\n",
       "      <th>B5</th>\n",
       "      <th>B6</th>\n",
       "      <th>B7</th>\n",
       "      <th>B8</th>\n",
       "      <th>B9</th>\n",
       "      <th>...</th>\n",
       "      <th>B_M_3</th>\n",
       "      <th>B_M_4</th>\n",
       "      <th>B_M_5</th>\n",
       "      <th>B_M_6</th>\n",
       "      <th>B_M_7</th>\n",
       "      <th>B_M_8</th>\n",
       "      <th>B_M_9</th>\n",
       "      <th>B_M_10</th>\n",
       "      <th>NDVI_F</th>\n",
       "      <th>NDVI_B</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>CalTech</td>\n",
       "      <td>0.280219</td>\n",
       "      <td>0.140441</td>\n",
       "      <td>0.085302</td>\n",
       "      <td>0.201208</td>\n",
       "      <td>0.140330</td>\n",
       "      <td>0.298553</td>\n",
       "      <td>0.159098</td>\n",
       "      <td>0.126776</td>\n",
       "      <td>0.320083</td>\n",
       "      <td>...</td>\n",
       "      <td>0.007900</td>\n",
       "      <td>0.026759</td>\n",
       "      <td>0.012998</td>\n",
       "      <td>0.087576</td>\n",
       "      <td>0.036512</td>\n",
       "      <td>0.016422</td>\n",
       "      <td>0.077901</td>\n",
       "      <td>0.032847</td>\n",
       "      <td>0.404546</td>\n",
       "      <td>0.432592</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>NEON_SJER</td>\n",
       "      <td>0.290705</td>\n",
       "      <td>0.272996</td>\n",
       "      <td>0.222365</td>\n",
       "      <td>0.257017</td>\n",
       "      <td>0.288366</td>\n",
       "      <td>0.330426</td>\n",
       "      <td>0.244312</td>\n",
       "      <td>0.239725</td>\n",
       "      <td>0.306095</td>\n",
       "      <td>...</td>\n",
       "      <td>0.096544</td>\n",
       "      <td>0.042156</td>\n",
       "      <td>0.019331</td>\n",
       "      <td>0.105384</td>\n",
       "      <td>0.105495</td>\n",
       "      <td>0.069892</td>\n",
       "      <td>0.038315</td>\n",
       "      <td>0.051702</td>\n",
       "      <td>0.072284</td>\n",
       "      <td>0.121598</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>2 rows × 53 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "       index        B1        B2        B3        B4        B5        B6  \\\n",
       "0    CalTech  0.280219  0.140441  0.085302  0.201208  0.140330  0.298553   \n",
       "1  NEON_SJER  0.290705  0.272996  0.222365  0.257017  0.288366  0.330426   \n",
       "\n",
       "         B7        B8        B9  ...     B_M_3     B_M_4     B_M_5     B_M_6  \\\n",
       "0  0.159098  0.126776  0.320083  ...  0.007900  0.026759  0.012998  0.087576   \n",
       "1  0.244312  0.239725  0.306095  ...  0.096544  0.042156  0.019331  0.105384   \n",
       "\n",
       "      B_M_7     B_M_8     B_M_9    B_M_10    NDVI_F    NDVI_B  \n",
       "0  0.036512  0.016422  0.077901  0.032847  0.404546  0.432592  \n",
       "1  0.105495  0.069892  0.038315  0.051702  0.072284  0.121598  \n",
       "\n",
       "[2 rows x 53 columns]"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "df = pd.read_csv(\"../../geo-tabular-models/data/angle/test.csv\")\n",
    "df.head(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "alpha",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
